{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.13.1\n",
      "sys.version_info(major=3, minor=7, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.0.3\n",
      "numpy 1.16.2\n",
      "pandas 0.24.2\n",
      "sklearn 0.20.3\n",
      "tensorflow 1.13.1\n",
      "tensorflow._api.v1.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras\n",
    "\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, tf, keras:\n",
    "    print(module.__name__, module.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   survived     sex   age  n_siblings_spouses  parch     fare  class     deck  \\\n",
      "0         0    male  22.0                   1      0   7.2500  Third  unknown   \n",
      "1         1  female  38.0                   1      0  71.2833  First        C   \n",
      "2         1  female  26.0                   0      0   7.9250  Third  unknown   \n",
      "3         1  female  35.0                   1      0  53.1000  First        C   \n",
      "4         0    male  28.0                   0      0   8.4583  Third  unknown   \n",
      "\n",
      "   embark_town alone  \n",
      "0  Southampton     n  \n",
      "1    Cherbourg     n  \n",
      "2  Southampton     y  \n",
      "3  Southampton     n  \n",
      "4   Queenstown     y  \n",
      "   survived     sex   age  n_siblings_spouses  parch     fare   class  \\\n",
      "0         0    male  35.0                   0      0   8.0500   Third   \n",
      "1         0    male  54.0                   0      0  51.8625   First   \n",
      "2         1  female  58.0                   0      0  26.5500   First   \n",
      "3         1  female  55.0                   0      0  16.0000  Second   \n",
      "4         1    male  34.0                   0      0  13.0000  Second   \n",
      "\n",
      "      deck  embark_town alone  \n",
      "0  unknown  Southampton     y  \n",
      "1        E  Southampton     y  \n",
      "2        C  Southampton     y  \n",
      "3  unknown  Southampton     y  \n",
      "4        D  Southampton     y  \n"
     ]
    }
   ],
   "source": [
    "# https://storage.googleapis.com/tf-datasets/titanic/train.csv\n",
    "# https://storage.googleapis.com/tf-datasets/titanic/eval.csv\n",
    "train_file = \"./data/titanic/train.csv\"\n",
    "eval_file = \"./data/titanic/eval.csv\"\n",
    "\n",
    "train_df = pd.read_csv(train_file)\n",
    "eval_df = pd.read_csv(eval_file)\n",
    "\n",
    "print(train_df.head())\n",
    "print(eval_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      sex   age  n_siblings_spouses  parch     fare  class     deck  \\\n",
      "0    male  22.0                   1      0   7.2500  Third  unknown   \n",
      "1  female  38.0                   1      0  71.2833  First        C   \n",
      "2  female  26.0                   0      0   7.9250  Third  unknown   \n",
      "3  female  35.0                   1      0  53.1000  First        C   \n",
      "4    male  28.0                   0      0   8.4583  Third  unknown   \n",
      "\n",
      "   embark_town alone  \n",
      "0  Southampton     n  \n",
      "1    Cherbourg     n  \n",
      "2  Southampton     y  \n",
      "3  Southampton     n  \n",
      "4   Queenstown     y  \n",
      "      sex   age  n_siblings_spouses  parch     fare   class     deck  \\\n",
      "0    male  35.0                   0      0   8.0500   Third  unknown   \n",
      "1    male  54.0                   0      0  51.8625   First        E   \n",
      "2  female  58.0                   0      0  26.5500   First        C   \n",
      "3  female  55.0                   0      0  16.0000  Second  unknown   \n",
      "4    male  34.0                   0      0  13.0000  Second        D   \n",
      "\n",
      "   embark_town alone  \n",
      "0  Southampton     y  \n",
      "1  Southampton     y  \n",
      "2  Southampton     y  \n",
      "3  Southampton     y  \n",
      "4  Southampton     y  \n",
      "0    0\n",
      "1    1\n",
      "2    1\n",
      "3    1\n",
      "4    0\n",
      "Name: survived, dtype: int64\n",
      "0    0\n",
      "1    0\n",
      "2    1\n",
      "3    1\n",
      "4    1\n",
      "Name: survived, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "y_train = train_df.pop('survived')\n",
    "y_eval = eval_df.pop('survived')\n",
    "\n",
    "print(train_df.head())\n",
    "print(eval_df.head())\n",
    "print(y_train.head())\n",
    "print(y_eval.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>n_siblings_spouses</th>\n",
       "      <th>parch</th>\n",
       "      <th>fare</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>627.000000</td>\n",
       "      <td>627.000000</td>\n",
       "      <td>627.000000</td>\n",
       "      <td>627.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>29.631308</td>\n",
       "      <td>0.545455</td>\n",
       "      <td>0.379585</td>\n",
       "      <td>34.385399</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>12.511818</td>\n",
       "      <td>1.151090</td>\n",
       "      <td>0.792999</td>\n",
       "      <td>54.597730</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.750000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>23.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>7.895800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>28.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>15.045800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>35.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>31.387500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>80.000000</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>512.329200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              age  n_siblings_spouses       parch        fare\n",
       "count  627.000000          627.000000  627.000000  627.000000\n",
       "mean    29.631308            0.545455    0.379585   34.385399\n",
       "std     12.511818            1.151090    0.792999   54.597730\n",
       "min      0.750000            0.000000    0.000000    0.000000\n",
       "25%     23.000000            0.000000    0.000000    7.895800\n",
       "50%     28.000000            0.000000    0.000000   15.045800\n",
       "75%     35.000000            1.000000    0.000000   31.387500\n",
       "max     80.000000            8.000000    5.000000  512.329200"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sex ['male' 'female']\n",
      "n_siblings_spouses [1 0 3 4 2 5 8]\n",
      "parch [0 1 2 5 3 4]\n",
      "class ['Third' 'First' 'Second']\n",
      "deck ['unknown' 'C' 'G' 'A' 'B' 'D' 'F' 'E']\n",
      "embark_town ['Southampton' 'Cherbourg' 'Queenstown' 'unknown']\n",
      "alone ['n' 'y']\n"
     ]
    }
   ],
   "source": [
    "categorical_columns = ['sex', 'n_siblings_spouses', 'parch', 'class',\n",
    "                       'deck', 'embark_town', 'alone']\n",
    "numeric_columns = ['age', 'fare']\n",
    "\n",
    "feature_columns = []\n",
    "for categorical_column in categorical_columns:\n",
    "    vocab = train_df[categorical_column].unique()\n",
    "    print(categorical_column, vocab)\n",
    "    feature_columns.append(\n",
    "        tf.feature_column.indicator_column(\n",
    "            tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "                categorical_column, vocab)))\n",
    "\n",
    "for categorical_column in numeric_columns:\n",
    "    feature_columns.append(\n",
    "        tf.feature_column.numeric_column(\n",
    "            categorical_column, dtype=tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_dataset(data_df, label_df, epochs = 10, shuffle = True,\n",
    "                 batch_size = 32):\n",
    "    dataset = tf.data.Dataset.from_tensor_slices(\n",
    "        (dict(data_df), label_df))\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(10000)\n",
    "    dataset = dataset.repeat(epochs).batch(batch_size)\n",
    "    return dataset.make_one_shot_iterator().get_next()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "INFO:tensorflow:Using config: {'_model_dir': 'customized_estimator', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x12daf3080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column.py:205: NumericColumn._get_dense_tensor (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column.py:2121: NumericColumn._transform_feature (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column_v2.py:2703: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column.py:206: NumericColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column.py:205: IndicatorColumn._get_dense_tensor (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column.py:2121: IndicatorColumn._transform_feature (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4295: VocabularyListCategoricalColumn._get_sparse_tensors (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column.py:2121: VocabularyListCategoricalColumn._transform_feature (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/ops/lookup_ops.py:1137: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4266: IndicatorColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4321: VocabularyListCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
      "WARNING:tensorflow:From <ipython-input-7-c57751d6dcfc>:12: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 0 into customized_estimator/model.ckpt.\n",
      "INFO:tensorflow:loss = 1.9022088, step = 1\n",
      "INFO:tensorflow:global_step/sec: 301.616\n",
      "INFO:tensorflow:loss = 0.29633158, step = 101 (0.341 sec)\n",
      "INFO:tensorflow:global_step/sec: 505.298\n",
      "INFO:tensorflow:loss = 0.6143039, step = 201 (0.189 sec)\n",
      "INFO:tensorflow:global_step/sec: 564.455\n",
      "INFO:tensorflow:loss = 0.32041603, step = 301 (0.177 sec)\n",
      "INFO:tensorflow:global_step/sec: 563.953\n",
      "INFO:tensorflow:loss = 0.40293896, step = 401 (0.177 sec)\n",
      "INFO:tensorflow:global_step/sec: 583.142\n",
      "INFO:tensorflow:loss = 0.3103894, step = 501 (0.172 sec)\n",
      "INFO:tensorflow:global_step/sec: 529.033\n",
      "INFO:tensorflow:loss = 0.5766603, step = 601 (0.189 sec)\n",
      "INFO:tensorflow:global_step/sec: 455.618\n",
      "INFO:tensorflow:loss = 0.34097248, step = 701 (0.219 sec)\n",
      "INFO:tensorflow:global_step/sec: 473.415\n",
      "INFO:tensorflow:loss = 0.25038648, step = 801 (0.212 sec)\n",
      "INFO:tensorflow:global_step/sec: 509.263\n",
      "INFO:tensorflow:loss = 0.505088, step = 901 (0.196 sec)\n",
      "INFO:tensorflow:global_step/sec: 465.56\n",
      "INFO:tensorflow:loss = 0.44969255, step = 1001 (0.216 sec)\n",
      "INFO:tensorflow:global_step/sec: 443.408\n",
      "INFO:tensorflow:loss = 0.24732417, step = 1101 (0.225 sec)\n",
      "INFO:tensorflow:global_step/sec: 427.855\n",
      "INFO:tensorflow:loss = 0.34562913, step = 1201 (0.235 sec)\n",
      "INFO:tensorflow:global_step/sec: 446.927\n",
      "INFO:tensorflow:loss = 0.27141798, step = 1301 (0.224 sec)\n",
      "INFO:tensorflow:global_step/sec: 455.317\n",
      "INFO:tensorflow:loss = 0.6561841, step = 1401 (0.218 sec)\n",
      "INFO:tensorflow:global_step/sec: 516.836\n",
      "INFO:tensorflow:loss = 0.40642688, step = 1501 (0.194 sec)\n",
      "INFO:tensorflow:global_step/sec: 529.462\n",
      "INFO:tensorflow:loss = 0.32950634, step = 1601 (0.189 sec)\n",
      "INFO:tensorflow:global_step/sec: 515.788\n",
      "INFO:tensorflow:loss = 0.47663295, step = 1701 (0.194 sec)\n",
      "INFO:tensorflow:global_step/sec: 526.901\n",
      "INFO:tensorflow:loss = 0.4040637, step = 1801 (0.190 sec)\n",
      "INFO:tensorflow:global_step/sec: 621.311\n",
      "INFO:tensorflow:loss = 0.34504694, step = 1901 (0.161 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 1960 into customized_estimator/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.15136214.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow_estimator.python.estimator.estimator.Estimator at 0x12da60ef0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_dir = \"customized_estimator\"\n",
    "if not os.path.exists(output_dir):\n",
    "    os.mkdir(output_dir)\n",
    "\n",
    "def model_fn(features, labels, mode, params):\n",
    "    # model runtime state: Train, Eval, Predict\n",
    "    input_for_next_layer = tf.feature_column.input_layer(\n",
    "        features, params['feature_columns'])\n",
    "    for n_unit in params['hidden_units']:\n",
    "        input_for_next_layer = tf.layers.dense(input_for_next_layer,\n",
    "                                               units = n_unit,\n",
    "                                               activation = tf.nn.relu)\n",
    "    logits = tf.layers.dense(input_for_next_layer,\n",
    "                             params['n_classes'],\n",
    "                             activation = None)\n",
    "    predicted_classes = tf.argmax(logits, 1)\n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        predictions = {\n",
    "            \"class_ids\": predicted_classes[:, tf.newaxis],\n",
    "            \"probabilities\": tf.nn.softmax(logits),\n",
    "            \"logits\": logits\n",
    "        }\n",
    "        return tf.estimator.EstimatorSpec(mode,\n",
    "                                          predictions = predictions)\n",
    "    \n",
    "    loss = tf.losses.sparse_softmax_cross_entropy(labels = labels,\n",
    "                                                  logits = logits)\n",
    "    accuracy = tf.metrics.accuracy(labels = labels,\n",
    "                                   predictions = predicted_classes,\n",
    "                                   name = \"acc_op\")\n",
    "    metrics = {\"accuracy\": accuracy}\n",
    "    if mode == tf.estimator.ModeKeys.EVAL:\n",
    "        return tf.estimator.EstimatorSpec(mode, loss = loss,\n",
    "                                          eval_metric_ops = metrics)\n",
    "    optimizer = tf.train.AdamOptimizer()\n",
    "    train_op = optimizer.minimize(\n",
    "        loss, global_step = tf.train.get_global_step())\n",
    "    return tf.estimator.EstimatorSpec(mode, loss = loss,\n",
    "                                      train_op = train_op)\n",
    "\n",
    "estimator = tf.estimator.Estimator(\n",
    "    model_fn = model_fn,\n",
    "    model_dir = output_dir,\n",
    "    params = {\n",
    "        \"feature_columns\": feature_columns,\n",
    "        \"hidden_units\": [100, 100],\n",
    "        \"n_classes\": 2\n",
    "    })\n",
    "estimator.train(input_fn = lambda : make_dataset(\n",
    "    train_df, y_train, epochs = 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Starting evaluation at 2019-06-12T13:50:23Z\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "WARNING:tensorflow:From /Users/zhangyx/workspace/environments/tf_py3/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use standard file APIs to check for files with this prefix.\n",
      "INFO:tensorflow:Restoring parameters from customized_estimator/model.ckpt-1960\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Finished evaluation at 2019-06-12-13:50:23\n",
      "INFO:tensorflow:Saving dict for global step 1960: accuracy = 0.7916667, global_step = 1960, loss = 0.5161865\n",
      "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1960: customized_estimator/model.ckpt-1960\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.7916667, 'loss': 0.5161865, 'global_step': 1960}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "estimator.evaluate(lambda : make_dataset(\n",
    "    eval_df, y_eval, epochs = 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
